{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import pickle\n",
    "import torch.nn.functional as F\n",
    "import matplotlib as plt\n",
    "import random\n",
    "\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model_v2 import EncoderRNN,DecoderRNN\n",
    "from config import *\n",
    "from pre_process import load_data, get_batch, pack_seqs, get_embedding_matrix\n",
    "USE_CUDA = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "BASE_DIR = '../data/'\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pickle.load(open(BASE_DIR + 'train_data_d.pkl', 'rb'))\n",
    "word2index = pickle.load(open(BASE_DIR + 'word_index.pkl', 'rb'))\n",
    "tag2index = pickle.load(open(BASE_DIR + 'tag2ix.pkl', 'rb'))\n",
    "ix2tag = pickle.load(open(BASE_DIR + 'ix2tag.pkl', 'rb'))\n",
    "word_vec = pickle.load(open(BASE_DIR+'wv_matrix.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(user_config={}):\n",
    "    BATCH_SIZE = user_config['batch']\n",
    "    alpha = user_config['alpha']\n",
    "    device = user_config['device']\n",
    "    epoch = user_config['epoch']\n",
    "    cur_config = \"b{}_a{}\".format(BATCH_SIZE, alpha)\n",
    "    print('current_config : batch_size {} alpha {} GPU {} epoch {}'.\n",
    "          format(BATCH_SIZE, alpha, device, epoch))\n",
    "\n",
    "\n",
    "    encoder = EncoderRNN(len(word2index)+1, EMBEDDING_DIM, ENCODER_DIM, word_vec)\n",
    "    decoder = DecoderRNN(len(ix2tag)+1, len(ix2tag) //7, DECODER_DIM,BATCH_SIZE)\n",
    "\n",
    "    encoder = encoder.to(device)\n",
    "    decoder = decoder.to(device)\n",
    "\n",
    "    bias = [0.0, 1]\n",
    "    bias.extend((alpha for i in range(1, len(ix2tag))))\n",
    "    tag_criterion = nn.CrossEntropyLoss(ignore_index=0, weight=torch.tensor(bias, device=device))\n",
    "\n",
    "    decoder.init_weights()\n",
    "    enc_optim = optim.RMSprop(encoder.parameters(), lr=0.001, alpha=0.9, eps=1e-7)\n",
    "    dec_optim = optim.RMSprop(decoder.parameters(), lr=0.001, alpha=0.9, eps=1e-7)\n",
    "    \n",
    "    split = round(0.05*len(train_data))\n",
    "    train_set = train_data[:split]\n",
    "    test_set = train_data[split:]\n",
    "    P, R, F1 = 0.0, 0.0, 0.0\n",
    "    \n",
    "    print(\"train_set {} \".format(len(train_set)))\n",
    "    hold_loss = []\n",
    "    for step in range(epoch):\n",
    "        losses = []\n",
    "        for i, batch in enumerate(get_batch(BATCH_SIZE, train_set)):\n",
    "            sents, tags, distances = zip(*batch)\n",
    "            if len(sents) < BATCH_SIZE:\n",
    "                break\n",
    "            x = torch.tensor(sents, device=device)\n",
    "            y = torch.tensor(tags, device=device)\n",
    "            dists = torch.tensor(distances, device=device).float()\n",
    "            x_lens = torch.tensor(pack_seqs(sents), device=device)\n",
    "            masked = x.gt(0).view(-1)\n",
    "            encoder_masking = x.lt(1)\n",
    "\n",
    "            encoder.zero_grad()\n",
    "            decoder.zero_grad()\n",
    "\n",
    "            enc_output, hidden_c = encoder(x, x_lens)\n",
    "            start_decode = torch.tensor([[1]] * BATCH_SIZE, device=device)\n",
    "            tag_score = decoder(start_decode, hidden_c, enc_output, encoder_masking)\n",
    "            pre = torch.argmax(tag_score, 1)\n",
    "            pre_m = masked.view(-1) * pre.byte()\n",
    "            tag_loss = tag_criterion(tag_score, y.view(-1))\n",
    "\n",
    "            loss = tag_loss\n",
    "\n",
    "            loss.backward()\n",
    "            losses.append(float(loss))\n",
    "\n",
    "            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5.0)\n",
    "            torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5.0)\n",
    "\n",
    "            enc_optim.step()\n",
    "            dec_optim.step()\n",
    "            \n",
    "            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'.format(\n",
    "                step+1, epoch, i, len(train_set)//BATCH_SIZE, np.mean(losses)\n",
    "            ))\n",
    "            losses=[]\n",
    "            if step>2:\n",
    "                hold_loss.append[loss]\n",
    "\n",
    "    print(\"Train Complete!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current_config : batch_size 64 alpha 10.0 GPU cpu epoch 20\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "__init__() missing 2 required positional arguments: 'input_size' and 'hidden_size'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-d9ad7d677228>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0muser_config\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'batch'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m64\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'alpha'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m10.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'device'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'epoch'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0muser_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-4-e9384c79a055>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(user_config)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0mencoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEncoderRNN\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword2index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mEMBEDDING_DIM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mENCODER_DIM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mword_vec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoderRNN\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mix2tag\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mix2tag\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m//\u001b[0m\u001b[0;36m7\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDECODER_DIM\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mBATCH_SIZE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Dev/projects/dev/projects/triplet_g/model_v2.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, input_size, embedding_dim, hidden_dim, word2vec, drop_out, device)\u001b[0m\n\u001b[1;32m     19\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_layers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLSTM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEmbedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membedding_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    423\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    424\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 425\u001b[0;31m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mLSTM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'LSTM'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    426\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: __init__() missing 2 required positional arguments: 'input_size' and 'hidden_size'"
     ]
    }
   ],
   "source": [
    "user_config = {'batch':64, 'alpha':10.0, 'device':'cpu', 'epoch':20}\n",
    "train(user_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(output, target, topk=(1,)):\n",
    "    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n",
    "    with torch.no_grad():\n",
    "        maxk = max(topk)\n",
    "        batch_size = target.size(0)\n",
    "\n",
    "        _, pred = output.topk(maxk, 1, True, True)\n",
    "        pred = pred.t()\n",
    "        correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
    "\n",
    "        res = []\n",
    "        for k in topk:\n",
    "            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)\n",
    "            res.append(correct_k.mul_(1 / (batch_size*50)))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = nn.LogSoftmax()\n",
    "loss = nn.NLLLoss()\n",
    "# input is of size N x C = 3 x 5\n",
    "input = torch.tensor([[1., 1. , 1.],[0., 1., 1.],[1.5, 1.0, 1.]])\n",
    "# each element in target has to have 0 <= value < C\n",
    "target = torch.tensor([1, 0, 2])\n",
    "print(m(input))\n",
    "output = loss(m(input), target)\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-2.4076, -1.4076, -0.4076],\n",
      "        [-1.4076, -2.4076, -0.4076],\n",
      "        [-1.0986, -1.0986, -1.0986]])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gaikin/miniconda3/lib/python3.7/site-packages/ipykernel/__main__.py:2: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  from ipykernel import kernelapp as app\n",
      "/home/gaikin/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(19.0761)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.Tensor([[1.0,2.0,3.0], [2.0, 1.0, 3.0], [0.0,0.0,0.0]])\n",
    "print(F.log_softmax(x))\n",
    "y = torch.LongTensor([1, 1, 0])\n",
    "w = torch.Tensor([10.0,5.0,1.0])\n",
    "F.cross_entropy(x,y,w, size_average=False, ignore_index=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "multi-target not supported at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:21",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-31c6ca336c9e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0my2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m   1548\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1549\u001b[0m         \u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1550\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1552\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mnll_loss\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m   1405\u001b[0m                          .format(input.size(0), target.size(0)))\n\u001b[1;32m   1406\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1407\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1408\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1409\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: multi-target not supported at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:21"
     ]
    }
   ],
   "source": [
    "y2 = torch.tensor([[0,10,0],[0,10,0],[0,0,0]])\n",
    "F.cross_entropy(x,y2, ignore_index=0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:miniconda3]",
   "language": "python",
   "name": "conda-env-miniconda3-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
